{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "beObUOFyuRjT"
      },
      "source": [
        "##### Copyright 2018 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s6D70EeAZe-Q"
      },
      "source": [
        "# 드라이버\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/4_drivers_tutorial\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org에서 보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/agents/tutorials/4_drivers_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab에서 실행하기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ko/agents/tutorials/4_drivers_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub에서 소스 보기</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ko/agents/tutorials/4_drivers_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">노트북 다운로드하기</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8aPHF9kXFggA"
      },
      "source": [
        "## 소개\n",
        "\n",
        "강화 학습의 일반적인 패턴은 지정된 수의 단계 또는 에피소드에 대해 환경에서 정책을 실행하는 것입니다. 이는 예를 들어 데이터 수집, 평가 및 에이전트의 비디오 생성 중에 발생합니다.\n",
        "\n",
        "Python으로 작성하는 것이 비교적 간단하지만, `tf.while` 루프, `tf.cond` 및 `tf.control_dependencies`가 포함되므로 TensorFlow에서 작성하고 디버깅하는 것이 훨씬 더 복잡합니다. 따라서 run 루프라는 개념을 `driver`라는 클래스로 추상화하고 Python 및 TensorFlow에서 잘 테스트된 구현을 제공합니다.\n",
        "\n",
        "또한, 각 단계에서 드라이버가 발견한 데이터는 Trajectory라는 명명된 튜플에 저장되고 재현 버퍼 및 메트릭과 같은 observer 세트로 브로드캐스팅됩니다. 이 데이터에는 환경의 관찰 값, 정책에서 권장하는 행동, 획득한 보상, 현재 유형 및 다음 단계 등이 포함됩니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t7PM1QfMZqkS"
      },
      "source": [
        "## 설정"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0w-Ykwl1bn4v"
      },
      "source": [
        "tf-agents 또는 gym을 아직 설치하지 않은 경우, 다음을 실행합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TnE2CgilrngG"
      },
      "outputs": [],
      "source": [
        "!pip install tf-agents\n",
        "!pip install gym"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "whYNP894FSkA"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.policies import random_py_policy\n",
        "from tf_agents.policies import random_tf_policy\n",
        "from tf_agents.metrics import py_metrics\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.drivers import py_driver\n",
        "from tf_agents.drivers import dynamic_episode_driver\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9V7DEcB8IeiQ"
      },
      "source": [
        "## Python 드라이버\n",
        "\n",
        "`PyDriver` 클래스는 Python environment, Python policy 및 각 단계에서 업데이트할 observer의 목록을 사용합니다. 기본 메서드는 `run()`이며, 다음 종료 기준 중 하나 이상이 충족될 때까지 정책의 행동을 사용하여 환경을 단계화합니다. 종료 기준은 단계 수가 `max_steps`에 도달하거나 에피소드 수가 `max_episodes`에 도달하는 것입니다.\n",
        "\n",
        "구현은 대략 다음과 같습니다.\n",
        "\n",
        "```python\n",
        "class PyDriver(object):\n",
        "\n",
        "  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):\n",
        "    self._env = env\n",
        "    self._policy = policy\n",
        "    self._observers = observers or []\n",
        "    self._max_steps = max_steps or np.inf\n",
        "    self._max_episodes = max_episodes or np.inf\n",
        "\n",
        "  def run(self, time_step, policy_state=()):\n",
        "    num_steps = 0\n",
        "    num_episodes = 0\n",
        "    while num_steps < self._max_steps and num_episodes < self._max_episodes:\n",
        "\n",
        "      # Compute an action using the policy for the given time_step\n",
        "      action_step = self._policy.action(time_step, policy_state)\n",
        "\n",
        "      # Apply the action to the environment and get the next step\n",
        "      next_time_step = self._env.step(action_step.action)\n",
        "\n",
        "      # Package information into a trajectory\n",
        "      traj = trajectory.Trajectory(\n",
        "         time_step.step_type,\n",
        "         time_step.observation,\n",
        "         action_step.action,\n",
        "         action_step.info,\n",
        "         next_time_step.step_type,\n",
        "         next_time_step.reward,\n",
        "         next_time_step.discount)\n",
        "\n",
        "      for observer in self._observers:\n",
        "        observer(traj)\n",
        "\n",
        "      # Update statistics to check termination\n",
        "      num_episodes += np.sum(traj.is_last())\n",
        "      num_steps += np.sum(~traj.is_boundary())\n",
        "\n",
        "      time_step = next_time_step\n",
        "      policy_state = action_step.state\n",
        "\n",
        "    return time_step, policy_state\n",
        "```\n",
        "\n",
        "이제 CartPole 환경에서 임의의 정책을 실행하여 결과를 재현 버퍼에 저장하고 일부 메트릭을 계산하는 예를 살펴보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dj4_-77_5ExP"
      },
      "outputs": [],
      "source": [
        "env = suite_gym.load('CartPole-v0')\n",
        "policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), \n",
        "                                         action_spec=env.action_spec())\n",
        "replay_buffer = []\n",
        "metric = py_metrics.AverageReturnMetric()\n",
        "observers = [replay_buffer.append, metric]\n",
        "driver = py_driver.PyDriver(\n",
        "    env, policy, observers, max_steps=20, max_episodes=1)\n",
        "\n",
        "initial_time_step = env.reset()\n",
        "final_time_step, _ = driver.run(initial_time_step)\n",
        "\n",
        "print('Replay Buffer:')\n",
        "for traj in replay_buffer:\n",
        "  print(traj)\n",
        "\n",
        "print('Average Return: ', metric.result())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X3Yrxg36Ik1x"
      },
      "source": [
        "## TensorFlow 드라이버\n",
        "\n",
        "또한, TensorFlow에는 기능적으로 Python 드라이버와 유사한 드라이버가 있지만, TF environments, TF policies, TF observers 등을 사용합니다. 현재 두 개의 TensorFlow 드라이버, 주어진 수의 (유효한) 환경 단계 후 종료하는 `DynamicStepDriver` 및 주어진 수의 에피소드 후에 종료하는 `DynamicEpisodeDriver`가 있습니다. 동작 중인 DynamicEpisode의 예를 살펴보겠습니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WC4ba3ObSceA"
      },
      "outputs": [],
      "source": [
        "env = suite_gym.load('CartPole-v0')\n",
        "tf_env = tf_py_environment.TFPyEnvironment(env)\n",
        "\n",
        "tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),\n",
        "                                            time_step_spec=tf_env.time_step_spec())\n",
        "\n",
        "\n",
        "num_episodes = tf_metrics.NumberOfEpisodes()\n",
        "env_steps = tf_metrics.EnvironmentSteps()\n",
        "observers = [num_episodes, env_steps]\n",
        "driver = dynamic_episode_driver.DynamicEpisodeDriver(\n",
        "    tf_env, tf_policy, observers, num_episodes=2)\n",
        "\n",
        "# Initial driver.run will reset the environment and initialize the policy.\n",
        "final_time_step, policy_state = driver.run()\n",
        "\n",
        "print('final_time_step', final_time_step)\n",
        "print('Number of Steps: ', env_steps.result().numpy())\n",
        "print('Number of Episodes: ', num_episodes.result().numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sz5jhHnU0fX1"
      },
      "outputs": [],
      "source": [
        "# Continue running from previous state\n",
        "final_time_step, _ = driver.run(final_time_step, policy_state)\n",
        "\n",
        "print('final_time_step', final_time_step)\n",
        "print('Number of Steps: ', env_steps.result().numpy())\n",
        "print('Number of Episodes: ', num_episodes.result().numpy())"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "4_drivers_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
